{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker PySpark K-Means Clustering MNIST Example\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Setup](#Setup)\n",
    "3. [Loading the Data](#Loading-the-Data)\n",
    "4. [Training with K-Means and Hosting a Model](#Training-with-K-Means-and-Hosting-a-Model)\n",
    "5. [Inference](#Inference)\n",
    "6. [Re-using existing endpoints or models to create a SageMakerModel](#Re-using-existing-endpoints-or-models-to-create-SageMakerModel)\n",
    "7. [Clean-up](#Clean-up)\n",
    "8. [More on SageMaker Spark](#More-on-SageMaker-Spark)\n",
    "\n",
    "## Introduction\n",
    "This notebook shows how to cluster handwritten digits with the SageMaker PySpark library.\n",
    "\n",
    "We will manipulate data in Spark through a SparkSession, and then use the SageMaker Spark library to interact with SageMaker for training and inference.\n",
    "We will first train a K-Means clustering model on the MNIST dataset with SageMaker. Then we will see how to re-use an existing endpoint, model data stored in S3, or a completed training job to run inference without training again.\n",
    "\n",
    "You can visit SageMaker Spark's GitHub repository at https://github.com/aws/sagemaker-spark to learn more about SageMaker Spark.\n",
    "\n",
    "This notebook was created and tested on an ml.m4.xlarge notebook instance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, we import the necessary modules and create a `SparkSession` with the SageMaker Spark dependency JARs on the driver classpath."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "from sagemaker import get_execution_role\n",
    "import sagemaker_pyspark\n",
    "\n",
    "role = get_execution_role()\n",
    "\n",
    "# Configure Spark to use the SageMaker Spark dependency jars\n",
    "classpath = \":\".join(sagemaker_pyspark.classpath_jars())\n",
    "\n",
    "# See the SageMaker Spark GitHub repository to learn how to connect to EMR from a notebook instance\n",
    "spark = SparkSession.builder.config(\"spark.driver.extraClassPath\", classpath)\\\n",
    "    .master(\"local[*]\").getOrCreate()\n",
    "    \n",
    "spark"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the Data\n",
    "\n",
    "Now we load the MNIST dataset into a Spark DataFrame. The dataset is available in LibSVM format at\n",
    "\n",
    "`s3://sagemaker-sample-data-[region]/spark/mnist/`\n",
    "\n",
    "where `[region]` is a supported AWS region, such as `us-east-1`.\n",
    "\n",
    "To train and make inferences, our input DataFrame must have a column of Doubles (named \"label\" by default) and a column of Vectors of Doubles (named \"features\" by default).\n",
    "\n",
    "Spark's LibSVM data source loads a DataFrame with exactly these two columns, so it is already suitable for training and inference.\n",
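    "\n",
    "For intuition, the row shape that the LibSVM reader produces can be sketched in plain Python. This is a minimal illustration, not Spark's implementation: `parse_libsvm_line` is a hypothetical helper and the sample line is made up. It assumes the standard LibSVM text layout (`label index:value ...`, one-based indices, zero values omitted), which Spark converts to zero-based vector positions.\n",
    "\n",
    "```python\n",
    "def parse_libsvm_line(line, num_features=784):\n",
    "    # Hypothetical helper: turns one LibSVM text line into a (label, features)\n",
    "    # pair shaped like a row of the DataFrame loaded below\n",
    "    tokens = line.split()\n",
    "    label = float(tokens[0])\n",
    "    features = [0.0] * num_features  # omitted entries are zeros\n",
    "    for token in tokens[1:]:\n",
    "        index, value = token.split(':')\n",
    "        # LibSVM indices are one-based; shift to zero-based positions\n",
    "        features[int(index) - 1] = float(value)\n",
    "    return label, features\n",
    "\n",
    "# Made-up sample line: digit 5 with three non-zero pixels\n",
    "label, features = parse_libsvm_line('5 153:3 154:18 155:18')\n",
    "print(label, len(features), features[152])  # 5.0 784 3.0\n",
    "```\n",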
    "\n",
    "Here, we load the data into a DataFrame in the SparkSession running on the local Notebook Instance, but you can connect your Notebook Instance to a remote Spark cluster for heavier workloads. Starting from EMR 5.11.0, SageMaker Spark is pre-installed on EMR Spark clusters. For more on connecting your SageMaker Notebook Instance to a remote EMR cluster, please see [this blog post](https://aws.amazon.com/blogs/machine-learning/build-amazon-sagemaker-notebooks-backed-by-spark-in-amazon-emr/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "region = boto3.Session().region_name\n",
    "spark._jsc.hadoopConfiguration().set('fs.s3a.endpoint', 's3.{}.amazonaws.com'.format(region))\n",
    "\n",
    "trainingData = spark.read.format('libsvm')\\\n",
    "    .option('numFeatures', '784')\\\n",
    "    .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))\n",
    "\n",
    "testData = spark.read.format('libsvm')\\\n",
    "    .option('numFeatures', '784')\\\n",
    "    .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))\n",
    "\n",
    "trainingData.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MNIST images are 28x28 pixels, for a total of 784 pixels. The dataset consists of images of the digits 0 through 9, representing 10 classes.\n",
    "\n",
    "In each row:\n",
    "* The `label` column identifies the image's label. For example, if the image of the handwritten number is the digit 5, the label value is 5.\n",
    "* The `features` column stores a vector (`org.apache.spark.ml.linalg.Vector`) of `Double` values. The length of the vector is 784, as each image consists of 784 pixels. Those pixels are the features we will use. \n",
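    "\n",
    "Because each `features` vector is a flattened 28x28 image, the pixel at row `r` and column `c` sits at position `r * 28 + c`. The sketch below assumes row-major ordering, consistent with the `reshape((28, 28))` used to display digits later in this notebook; `to_image` is a hypothetical helper for illustration only.\n",
    "\n",
    "```python\n",
    "def to_image(features, side=28):\n",
    "    # Hypothetical helper: split a flat, row-major pixel vector into rows\n",
    "    if len(features) != side * side:\n",
    "        raise ValueError('expected {} values'.format(side * side))\n",
    "    return [features[r * side:(r + 1) * side] for r in range(side)]\n",
    "\n",
    "flat = [0.0] * 784\n",
    "flat[5 * 28 + 14] = 255.0  # pixel at row 5, column 14\n",
    "image = to_image(flat)\n",
    "print(len(image), len(image[0]), image[5][14])  # 28 28 255.0\n",
    "```\n",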
    "\n",
    "Because we want to cluster the digit images, each image's 784 pixel values form its feature vector (a feature dimension of 784), and the 10 classes give the number of clusters we want to find (k = 10)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with K-Means and Hosting a Model\n",
    "Now we create a KMeansSageMakerEstimator, which uses the KMeans Amazon SageMaker Algorithm to train on our input data, and uses the KMeans Amazon SageMaker model image to host our model.\n",
    "\n",
    "Calling fit() on this estimator will train our model on Amazon SageMaker, and then create an Amazon SageMaker Endpoint to host our model.\n",
    "\n",
    "We can then use the SageMakerModel returned by this call to fit() to transform DataFrames using our hosted model.\n",
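    "\n",
    "Training itself runs on SageMaker, but the core idea of k-means can be sketched in a few lines. The code below is plain Lloyd's algorithm (alternate between assigning each point to its nearest centroid and moving each centroid to the mean of its points), written for intuition only. It is not the algorithm SageMaker runs, whose K-Means implementation differs and is designed for large datasets; `lloyd_kmeans` and `squared_distance` are hypothetical names.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def squared_distance(a, b):\n",
    "    # Hypothetical helper: squared Euclidean distance between two vectors\n",
    "    return sum((x - y) ** 2 for x, y in zip(a, b))\n",
    "\n",
    "def lloyd_kmeans(points, k, iterations=10, seed=0):\n",
    "    # Toy Lloyd's k-means for intuition; not SageMaker's implementation\n",
    "    rng = random.Random(seed)\n",
    "    centroids = [list(p) for p in rng.sample(points, k)]\n",
    "    for _ in range(iterations):\n",
    "        # Assignment step: group each point with its nearest centroid\n",
    "        clusters = [[] for _ in range(k)]\n",
    "        for p in points:\n",
    "            nearest = min(range(k), key=lambda j: squared_distance(p, centroids[j]))\n",
    "            clusters[nearest].append(p)\n",
    "        # Update step: move each centroid to the mean of its assigned points\n",
    "        for j, members in enumerate(clusters):\n",
    "            if members:\n",
    "                centroids[j] = [sum(dim) / len(members) for dim in zip(*members)]\n",
    "    return centroids\n",
    "\n",
    "points = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]\n",
    "print(sorted(lloyd_kmeans(points, k=2)))  # [[0.0, 0.5], [10.0, 10.5]]\n",
    "```\n",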
    "\n",
    "The following cell runs a training job and creates an endpoint to host the resulting model, so this cell can take up to twenty minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker_pyspark import IAMRole\n",
    "from sagemaker_pyspark.algorithms import KMeansSageMakerEstimator\n",
    "from sagemaker_pyspark import RandomNamePolicyFactory\n",
    "\n",
    "# Create K-Means Estimator\n",
    "kmeans_estimator = KMeansSageMakerEstimator(\n",
    "    sagemakerRole = IAMRole(role),\n",
    "    trainingInstanceType = 'ml.m4.xlarge', # Instance type to train K-means on SageMaker\n",
    "    trainingInstanceCount = 1,\n",
    "    endpointInstanceType = 'ml.t2.large', # Instance type to serve model (endpoint) for inference\n",
    "    endpointInitialInstanceCount = 1,\n",
    "    namePolicyFactory = RandomNamePolicyFactory(\"sparksm-1a-\")) # All the resources created are prefixed with sparksm-1a-\n",
    "\n",
    "# Set parameters for K-Means\n",
    "kmeans_estimator.setFeatureDim(784)\n",
    "kmeans_estimator.setK(10)\n",
    "\n",
    "# Train\n",
    "initialModel = kmeans_estimator.fit(trainingData)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put this `KMeansSageMakerEstimator` back into context, let's look at the architecture below, which shows what runs on the notebook instance and what runs on SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![SageMaker Spark K-Means architecture](img/sagemaker-spark-kmeans-architecture.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll need the name of the SageMaker endpoint hosting the K-Means model later on. It is available directly as the `endpointName` attribute of the `SageMakerModel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initialModelEndpointName = initialModel.endpointName\n",
    "print(initialModelEndpointName)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "Now we transform our DataFrame.\n",
    "To do this, each row's \"features\" Vector of Doubles is serialized to Protobuf and sent to the Amazon SageMaker Endpoint for inference, and the Protobuf responses are deserialized back into new columns of our DataFrame. The `transform()` method handles this serialization and deserialization automatically:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference on the test data and show some results\n",
    "transformedData = initialModel.transform(testData)\n",
    "\n",
    "transformedData.show()"
   ]
  },
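  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the default K-Means deserializer, the transformed DataFrame gains a `closest_cluster` column and a `distance_to_cluster` column. To make their meaning concrete, the sketch below computes both values for one vector given a list of centroids. It is a hypothetical helper, not code from SageMaker Spark; it assumes Euclidean distance to the nearest centroid, and the centroids are made up.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def assign_to_cluster(point, centroids):\n",
    "    # Hypothetical helper: returns (closest_cluster, distance_to_cluster),\n",
    "    # assuming Euclidean distance to the nearest centroid\n",
    "    distances = [math.sqrt(sum((x - y) ** 2 for x, y in zip(point, c))) for c in centroids]\n",
    "    closest = min(range(len(centroids)), key=lambda j: distances[j])\n",
    "    return closest, distances[closest]\n",
    "\n",
    "centroids = [[0.0, 0.0], [3.0, 4.0]]\n",
    "print(assign_to_cluster([3.0, 3.0], centroids))  # (1, 1.0)\n",
    "```"
   ]
  },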
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How well did the algorithm perform? Let us display the digits from each of the clusters and manually inspect the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import string\n",
    "\n",
    "# Helper function to display a digit\n",
    "def showDigit(img, caption='', xlabel='', subplot=None):\n",
    "    if subplot is None:\n",
    "        _, subplot = plt.subplots(1, 1)\n",
    "    imgr = np.asarray(img).reshape((28, 28))\n",
    "    subplot.axes.get_xaxis().set_ticks([])\n",
    "    subplot.axes.get_yaxis().set_ticks([])\n",
    "    subplot.set_title(caption)\n",
    "    subplot.set_xlabel(xlabel)\n",
    "    subplot.imshow(imgr, cmap='gray')\n",
    "\n",
    "def displayClusters(data, numRows=250, k=10):\n",
    "    # Take features and cluster assignments from the same rows so they stay aligned\n",
    "    rows = data.select(\"features\", \"closest_cluster\").take(numRows)\n",
    "\n",
    "    for cluster in range(k):\n",
    "        print('\\n\\n\\nCluster {}:'.format(string.ascii_uppercase[cluster]))\n",
    "        digits = [row.features.toArray() for row in rows if int(row.closest_cluster) == cluster]\n",
    "        if not digits:\n",
    "            print('No digits assigned to this cluster in the sample.')\n",
    "            continue\n",
    "        width = 5\n",
    "        height = ((len(digits) - 1) // width) + 1\n",
    "        plt.rcParams[\"figure.figsize\"] = (width, height)\n",
    "        _, subplots = plt.subplots(height, width)\n",
    "        subplots = np.ndarray.flatten(subplots)\n",
    "        for subplot, image in zip(subplots, digits):\n",
    "            showDigit(image, subplot=subplot)\n",
    "        for subplot in subplots[len(digits):]:\n",
    "            subplot.axis('off')\n",
    "\n",
    "        plt.show()\n",
    "\n",
    "displayClusters(transformedData)"
   ]
  },
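  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visual inspection can be complemented by a simple quantitative check. Because MNIST has true labels, one common measure is cluster purity: label each cluster with its most frequent digit and count the fraction of rows that match. The sketch below is a hypothetical helper in plain Python. It assumes you have collected `(label, closest_cluster)` pairs to the driver, for example with `transformedData.select('label', 'closest_cluster').collect()`, which is only reasonable for a small DataFrame such as this test set. The sample pairs are made up.\n",
    "\n",
    "```python\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "def cluster_purity(pairs):\n",
    "    # Hypothetical helper: pairs is an iterable of (label, closest_cluster)\n",
    "    counts = defaultdict(Counter)\n",
    "    for label, cluster in pairs:\n",
    "        counts[int(cluster)][int(label)] += 1\n",
    "    # Most frequent digit in each cluster\n",
    "    majority = {c: cnt.most_common(1)[0][0] for c, cnt in counts.items()}\n",
    "    matched = sum(cnt.most_common(1)[0][1] for cnt in counts.values())\n",
    "    total = sum(sum(cnt.values()) for cnt in counts.values())\n",
    "    return majority, matched / total\n",
    "\n",
    "pairs = [(5.0, 0.0), (5.0, 0.0), (3.0, 0.0), (7.0, 1.0), (7.0, 1.0)]\n",
    "majority, purity = cluster_purity(pairs)\n",
    "print(majority, purity)  # {0: 5, 1: 7} 0.8\n",
    "```"
   ]
  },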
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've seen how to use Spark to load data and SageMaker to train and run inference on it, we will look at how to create a `SageMakerModel` from an existing endpoint, from model data in S3, or from a completed training job."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Re-using existing endpoints or models to create `SageMakerModel`\n",
    "\n",
    "SageMaker Spark supports connecting a `SageMakerModel` to an existing SageMaker endpoint, or to an Endpoint created by reference to model data in S3, or to a previously completed Training Job.\n",
    "\n",
    "This allows you to use SageMaker Spark just for model hosting and inference on Spark-scale DataFrames without running a new Training Job.\n",
    "\n",
    "### Endpoint re-use\n",
    "\n",
    "Here we connect to the initial endpoint we created by using its unique name. The endpoint name can be retrieved from the console or from the `endpointName` attribute of the model you created. In our case, we saved it earlier in a variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ENDPOINT_NAME = initialModelEndpointName\n",
    "print(ENDPOINT_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have the endpoint name, we need to make sure that no new endpoint is created, since we are attaching to an existing one. We do this by setting the `endpointCreationPolicy` parameter to `EndpointCreationPolicy.DO_NOT_CREATE`. Because the endpoint serves a K-Means model, we also use the `KMeansProtobufResponseRowDeserializer`, so that the endpoint's output is deserialized correctly and returned to Spark as a DataFrame with the right columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker_pyspark import SageMakerModel\n",
    "from sagemaker_pyspark import EndpointCreationPolicy\n",
    "from sagemaker_pyspark.transformation.serializers import ProtobufRequestRowSerializer\n",
    "from sagemaker_pyspark.transformation.deserializers import KMeansProtobufResponseRowDeserializer\n",
    "\n",
    "attachedModel = SageMakerModel(\n",
    "    existingEndpointName = ENDPOINT_NAME,\n",
    "    endpointCreationPolicy = EndpointCreationPolicy.DO_NOT_CREATE,\n",
    "    endpointInstanceType = None, # Required argument; unused because no endpoint is created\n",
    "    endpointInitialInstanceCount = None, # Required argument; unused because no endpoint is created\n",
    "    requestRowSerializer = ProtobufRequestRowSerializer(featuresColumnName = \"features\"), # Optional: already default value\n",
    "    responseRowDeserializer = KMeansProtobufResponseRowDeserializer( # Optional: already default values\n",
    "      distance_to_cluster_column_name = \"distance_to_cluster\",\n",
    "      closest_cluster_column_name = \"closest_cluster\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "As the data we pass through the model uses the default column names for both the input to the model (`features`) and the output of the model (`distance_to_cluster` and `closest_cluster`), we do not need to specify the column names in the serializer and deserializer. If your column names are different, you can set them as shown above in the `requestRowSerializer` and `responseRowDeserializer`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use the `SageMakerModel.fromEndpoint` method to create the same model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformedData2 = attachedModel.transform(testData)\n",
    "transformedData2.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create model and endpoint from model data\n",
    "\n",
    "You can create a SageMakerModel and an Endpoint by referring directly to your model data in S3. To do this, you need the S3 path where the model data is saved, the IAM role, and the inference image to use. Here we reuse the model data of the K-Means model trained above; we can retrieve all three values from the `initialModel` variable or from the console."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker_pyspark import S3DataPath\n",
    "\n",
    "MODEL_S3_PATH = S3DataPath(initialModel.modelPath.bucket, initialModel.modelPath.objectPath)\n",
    "MODEL_ROLE_ARN = initialModel.modelExecutionRoleARN\n",
    "MODEL_IMAGE_PATH = initialModel.modelImage\n",
    "\n",
    "print(\"s3://{}/{}\".format(MODEL_S3_PATH.bucket, MODEL_S3_PATH.objectPath))\n",
    "print(MODEL_ROLE_ARN)\n",
    "print(MODEL_IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
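    "As when we attached to a running endpoint, we construct a `SageMakerModel` directly, this time specifying the model data with `modelPath`, `modelExecutionRoleARN`, and `modelImage`. Because a new endpoint is needed, we also specify the endpoint configuration (`endpointInstanceType` and `endpointInitialInstanceCount`), much as we did when creating the estimator. With `EndpointCreationPolicy.CREATE_ON_TRANSFORM`, the endpoint is created the first time `transform()` is called."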
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker_pyspark import RandomNamePolicy\n",
    "\n",
    "retrievedModel = SageMakerModel(\n",
    "    modelPath = MODEL_S3_PATH,\n",
    "    modelExecutionRoleARN = MODEL_ROLE_ARN,\n",
    "    modelImage = MODEL_IMAGE_PATH,\n",
    "    endpointInstanceType = \"ml.t2.medium\",\n",
    "    endpointInitialInstanceCount = 1,\n",
    "    requestRowSerializer = ProtobufRequestRowSerializer(), \n",
    "    responseRowDeserializer = KMeansProtobufResponseRowDeserializer(),\n",
    "    namePolicy = RandomNamePolicy(\"sparksm-1b-\"), \n",
    "    endpointCreationPolicy = EndpointCreationPolicy.CREATE_ON_TRANSFORM\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is also possible to use the `SageMakerModel.fromModelS3Path` method that takes the same parameters and produces the same model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformedData3 = retrievedModel.transform(testData)\n",
    "transformedData3.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create model and endpoint from a training job\n",
    "\n",
    "You can create a SageMakerModel and an Endpoint by referring to a previously completed training job. The only difference from creating them from model data in S3 is that, instead of the model data path, you provide the `trainingJobName` to `SageMakerModel.fromTrainingJob`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace with the name of a completed K-Means training job, e.g. from the SageMaker console\n",
    "TRAINING_JOB_NAME = \"<YOUR_TRAINING_JOB_NAME>\"\n",
    "MODEL_ROLE_ARN = initialModel.modelExecutionRoleARN\n",
    "MODEL_IMAGE_PATH = initialModel.modelImage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelFromJob = SageMakerModel.fromTrainingJob(\n",
    "    trainingJobName = TRAINING_JOB_NAME,\n",
    "    modelExecutionRoleARN = MODEL_ROLE_ARN,\n",
    "    modelImage = MODEL_IMAGE_PATH,\n",
    "    endpointInstanceType = \"ml.t2.medium\",\n",
    "    endpointInitialInstanceCount = 1,\n",
    "    requestRowSerializer = ProtobufRequestRowSerializer(), \n",
    "    responseRowDeserializer = KMeansProtobufResponseRowDeserializer(),\n",
    "    namePolicy = RandomNamePolicy(\"sparksm-1c-\"),\n",
    "    endpointCreationPolicy = EndpointCreationPolicy.CREATE_ON_TRANSFORM\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformedData4 = modelFromJob.transform(testData)\n",
    "transformedData4.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean-up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we don't need to make any more inferences, we now delete the resources (endpoints, models, configurations, etc.):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the resources\n",
    "from sagemaker_pyspark import SageMakerResourceCleanup\n",
    "\n",
    "def cleanUp(model):\n",
    "    resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient)\n",
    "    resource_cleanup.deleteResources(model.getCreatedResources())\n",
    "\n",
    "# Don't forget to include any models or pipeline models that you created in the notebook\n",
    "models = [initialModel, retrievedModel, modelFromJob]\n",
    "\n",
    "# Delete regular SageMakerModels\n",
    "for m in models:\n",
    "    cleanUp(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More on SageMaker Spark\n",
    "\n",
    "The SageMaker Spark GitHub repository has more about SageMaker Spark, including how to use it with the Scala SDK: https://github.com/aws/sagemaker-spark\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
